/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "op_dtype_selection_strategy_allow_mix_precision.h"
#include "ops_store/ops_kernel_manager.h"

namespace fe {
OpDtypeSelectionStrategyAllowMixPrecision::OpDtypeSelectionStrategyAllowMixPrecision(
    const std::string engine_name,
    FormatDtypeQuerierPtr format_dtype_querier_ptr,
    OpDtypeMixPrecisionMatcherPtr op_dtype_mixed_precision_matcher_ptr,
    OpDtypeRiseMatcherPtr op_dtype_rise_matcher_ptr,
    OpDtypeReduceMatcherPtr op_dtype_reduce_matcher_ptr)
    : OpDtypeSeletionStrategyBase(format_dtype_querier_ptr),
      engine_name_(engine_name),
      op_dtype_mixed_precision_matcher_ptr_(op_dtype_mixed_precision_matcher_ptr),
      op_dtype_rise_matcher_ptr_(op_dtype_rise_matcher_ptr),
      op_dtype_reduce_matcher_ptr_(op_dtype_reduce_matcher_ptr) {}

/** Nothing to release explicitly; members clean themselves up. */
OpDtypeSelectionStrategyAllowMixPrecision::~OpDtypeSelectionStrategyAllowMixPrecision() = default;

/**
 * Reads the precision policy recorded in a kernel's op store info.
 *
 * @param op_kernel_info_ptr kernel info to read from; a null pointer is
 *                           rejected by FE_CHECK_NOTNULL.
 * @param precision_policy   [out] receives the policy found in the op store info.
 * @return SUCCESS once the policy has been copied out.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::GetPrecisionPolicy(const OpKernelInfoPtr& op_kernel_info_ptr,
                                                                     PrecisionPolicy& precision_policy) {
  FE_CHECK_NOTNULL(op_kernel_info_ptr);

  const auto& store_info = op_kernel_info_ptr->GetOpStoreInfo();
  precision_policy = store_info.precision_policy;
  return SUCCESS;
}

/**
 * Selects candidate dtypes for one tensor of an op whose precision policy is WHITE.
 *
 * - When the tensor's current dtype is neither DT_FLOAT nor DT_FLOAT16, the
 *   decision is delegated to OpDtypeSelectionStrategyDefaultMode.
 * - Otherwise the dtypes supported by the kernel are queried and passed to the
 *   mixed-precision matcher together with basic_info.matched_index_vec.
 *   A matcher failure is tolerated: a warning is logged and SUCCESS is returned.
 *
 * @param basic_info selection context; node, cur_tensor_desc, op_kernel_info_ptr,
 *                   tensor_kernel_info_ptr, index and matched_index_vec are used.
 * @return the default-mode result on the non-float path; FAILED when the
 *         supported dtypes cannot be queried; SUCCESS otherwise. A null
 *         node/op desc/tensor desc/querier/matcher is rejected by FE_CHECK_NOTNULL.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::RunForOpInWhiteList(FormatDtypeSelectionBasicInfo& basic_info) {
  FE_CHECK_NOTNULL(basic_info.node);
  auto cur_op_desc_ptr = basic_info.node->GetOpDesc();
  FE_CHECK_NOTNULL(cur_op_desc_ptr);
  /* The tensor desc is dereferenced right below, so validate it first. */
  FE_CHECK_NOTNULL(basic_info.cur_tensor_desc);

  ge::DataType origin_dtype = basic_info.cur_tensor_desc->GetDataType();
  bool dtype_float_flag = (origin_dtype == ge::DT_FLOAT || origin_dtype == ge::DT_FLOAT16);
  /* If the op is in white list but its original data type is not
   * fp32 or fp16, we fall back to the default mode to select data type. */
  if (!dtype_float_flag) {
    FE_LOGD("The data type of tensor %u of op %s is not fp32 or fp16!", basic_info.index,
            cur_op_desc_ptr->GetName().c_str());
    FE_LOGD("Try to match original data type %u.", origin_dtype);
    DefaultSelector default_select_mode(new OpDtypeSelectionStrategyDefaultMode(format_dtype_querier_ptr_,
                                                                                op_dtype_rise_matcher_ptr_));
    return default_select_mode->Run(basic_info);
  }

  /* Both collaborators are dereferenced on the float path. */
  FE_CHECK_NOTNULL(format_dtype_querier_ptr_);
  FE_CHECK_NOTNULL(op_dtype_mixed_precision_matcher_ptr_);

  /* Only pick fp16 as its dtype, if it does not support fp16, we will
   * pick the higher precision version. */
  std::vector<ge::DataType> input_or_output_dtype_vec;
  if (format_dtype_querier_ptr_->GetSupportDataTypes(basic_info.op_kernel_info_ptr, basic_info.tensor_kernel_info_ptr,
                                                     *(cur_op_desc_ptr.get()),
                                                     input_or_output_dtype_vec) != SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][DtypeJdg][MixedPrcsn][Op %s type %s] Fail to get the support data_types.",
                    cur_op_desc_ptr->GetName().c_str(), cur_op_desc_ptr->GetType().c_str());
    return FAILED;
  }

  Status ret = op_dtype_mixed_precision_matcher_ptr_->Match(input_or_output_dtype_vec, origin_dtype,
                                                            basic_info.matched_index_vec);
  if (ret != SUCCESS) {
    /* We allow the node in white list using fp32, so here we just report
     * a warning log and still return success. */
    FE_LOGW("[GraphOpt][DtypeJdg][MixedPrecision][Op %s type %s]is in white list but it doesn't support fp16!",
            cur_op_desc_ptr->GetName().c_str(), cur_op_desc_ptr->GetType().c_str());
  }
  return SUCCESS;
}

/**
 * Selects candidate dtypes for one tensor of an op whose precision policy is BLACK.
 *
 * The selection is delegated to OpDtypeSelectionStrategyDefaultMode. When that
 * fails, an error naming the node and the tensor's dtype is reported; if the
 * dtype is DT_FLOAT or DT_FLOAT16 a second error flags the black-list
 * configuration itself.
 *
 * @param basic_info selection context; node and cur_tensor_desc are used for
 *                   diagnostics and are rejected by FE_CHECK_NOTNULL when null.
 * @return the status returned by the default-mode strategy.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::RunForOpInBlackList(FormatDtypeSelectionBasicInfo& basic_info) {
  /* Both pointers are dereferenced below when building the diagnostics. */
  FE_CHECK_NOTNULL(basic_info.node);
  FE_CHECK_NOTNULL(basic_info.cur_tensor_desc);

  DefaultSelector default_select_mode(new OpDtypeSelectionStrategyDefaultMode(format_dtype_querier_ptr_,
                                                                              op_dtype_rise_matcher_ptr_));
  const Status ret = default_select_mode->Run(basic_info);
  if (ret == SUCCESS) {
    return ret;
  }

  /* The name and dtype are only needed for error reporting, so fetch them lazily. */
  const std::string node_name = basic_info.node->GetName();
  const auto dtype = basic_info.cur_tensor_desc->GetDataType();
  REPORT_FE_ERROR(
      "[GraphOpt][DtypeJdg][RunBlackListOp] Op %s is in blacklist but doesn't support it's original dtype %s",
      node_name.c_str(),
      ge::TypeUtils::DataTypeToSerialString(dtype).c_str());
  if (dtype == ge::DT_FLOAT || dtype == ge::DT_FLOAT16) {
    REPORT_FE_ERROR("[GraphOpt][DtypeJdg][RunBlackListOp] Op %s should not be configured as blacklist op",
                    node_name.c_str());
  }
  return ret;
}

/**
 * Tells whether an op is a Cast from fp16 to fp32 that is not on the black list.
 *
 * @param cur_op_desc_ptr         op desc to inspect (the caller passes the father op).
 * @param fahter_out_anchor_index index of the output desc whose dtype is checked.
 * @return true only when the op type equals CAST, its precision policy was
 *         queried successfully and is not BLACK, input 0 is DT_FLOAT16 and the
 *         given output is DT_FLOAT. Returns false in every other case, including
 *         a null op desc or missing input/output descs.
 */
bool OpDtypeSelectionStrategyAllowMixPrecision::IsOpFp16ToFp32Cast(const ge::OpDescPtr& cur_op_desc_ptr,
                                                                   const uint32_t& fahter_out_anchor_index) {
  if (cur_op_desc_ptr == nullptr) {
    return false;
  }
  const std::string op_type = cur_op_desc_ptr->GetType();
  if (op_type != CAST) {
    return false;
  }

  PrecisionPolicy precision_policy = GRAY;
  if (QueryPrecisionPolicy(cur_op_desc_ptr, precision_policy) != SUCCESS) {
    FE_LOGD("Failed to query the precision policy of Cast %s, treat it as a normal Cast.",
            cur_op_desc_ptr->GetName().c_str());
    return false;
  }
  /* If Cast is in black list, we cannot jump over it. So we return false to
   * consider it as normal Cast. */
  if (precision_policy == BLACK) {
    FE_LOGD("Father of %s is BLACK Cast", cur_op_desc_ptr->GetName().c_str());
    return false;
  }

  auto father_output_desc = cur_op_desc_ptr->GetOutputDescPtr(fahter_out_anchor_index);
  auto father_input_desc = cur_op_desc_ptr->GetInputDescPtr(0);
  /* Guard against a missing desc before reading the data types. */
  if (father_output_desc == nullptr || father_input_desc == nullptr) {
    FE_LOGD("Cast %s has no valid input 0 or output %u, treat it as a normal Cast.",
            cur_op_desc_ptr->GetName().c_str(), fahter_out_anchor_index);
    return false;
  }

  if (father_output_desc->GetDataType() == ge::DT_FLOAT && father_input_desc->GetDataType() == ge::DT_FLOAT16) {
    FE_LOGD("Father of %s is %u Cast", cur_op_desc_ptr->GetName().c_str(), precision_policy);
    return true;
  }
  FE_LOGD("Cast %s is not a fp16 to fp32 Cast.", cur_op_desc_ptr->GetName().c_str());
  return false;
}

/**
 * Looks up the precision policy of an op type through the ops kernel manager.
 *
 * The kernel info is fetched from OpsKernelManager for engine_name_ with the
 * EN_IMPL_HW_TBE implementation type and the op's type string.
 *
 * @param op_desc_ptr      op whose type is looked up; null is rejected by FE_CHECK_NOTNULL.
 * @param precision_policy [out] set from the kernel's op store info on success;
 *                         left untouched when the kernel is not found.
 * @return SUCCESS when a kernel info was found, FAILED (with an error report) otherwise.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::QueryPrecisionPolicy(const ge::OpDescPtr &op_desc_ptr,
                                                                       PrecisionPolicy &precision_policy) {
  FE_CHECK_NOTNULL(op_desc_ptr);

  auto kernel_info =
      OpsKernelManager::Instance(engine_name_).GetOpKernelInfoByOpType(EN_IMPL_HW_TBE, op_desc_ptr->GetType());
  if (kernel_info != nullptr) {
    precision_policy = kernel_info->GetOpStoreInfo().precision_policy;
    return SUCCESS;
  }

  REPORT_FE_ERROR("[GraphOpt][DtypeJdg][QryPrecisPolicy] op %s is not found in tbe built-in store.",
                  op_desc_ptr->GetType().c_str());
  return FAILED;
}

void OpDtypeSelectionStrategyAllowMixPrecision::MatchForGray(const string &cur_op_desc_type,
    const string &cur_op_desc_name, const vector<ge::DataType> &op_kernel_dtype_vec,
    ge::DataType father_output_dtype, FormatDtypeSelectionBasicInfo& basic_info) {
  FE_LOGD("Op[name=%s,type=%s]: match father dtype, the expected dtype is [%u].", cur_op_desc_name.c_str(),
          cur_op_desc_type.c_str(), father_output_dtype);
  Status match_father_dtype_res =
      op_dtype_rise_matcher_ptr_->Match(op_kernel_dtype_vec, father_output_dtype, basic_info.matched_index_vec);
  if (match_father_dtype_res != SUCCESS) {
    FE_LOGD("Precision loss is allowed, try to match low precision dtype.");
    match_father_dtype_res =
        op_dtype_reduce_matcher_ptr_->Match(op_kernel_dtype_vec, father_output_dtype, basic_info.matched_index_vec);
  }
  if (match_father_dtype_res == SUCCESS) {
    FE_LOGD("Op[name=%s,type=%s]: match the father dtype success, some matched dtypes in op kernel have been found.",
            cur_op_desc_name.c_str(), cur_op_desc_type.c_str());
    FE_LOGD("The size of dtype is [%zu], the size of matched index is [%zu].", op_kernel_dtype_vec.size(),
            basic_info.matched_index_vec.size());
  } else {
    FE_LOGD("Op[name=%s,type=%s]: no matched the dtype %u, matchedIndexVec remain the same.",
            cur_op_desc_name.c_str(), cur_op_desc_type.c_str(), father_output_dtype);
  }
}

/**
 * Selects candidate dtypes for one tensor of an op handled as gray-list.
 *
 * 1. When CheckHasNoFather reports no father for this tensor, the decision is
 *    delegated to OpDtypeSelectionStrategyAllowFp32ToFp16.
 * 2. Otherwise the father is the owner of the peer output anchor. If that
 *    father is a non-black fp16->fp32 Cast (see IsOpFp16ToFp32Cast), it is
 *    skipped and the node feeding the Cast's input 0 becomes the father.
 * 3. The kernel's supported dtypes are matched against the father's output
 *    dtype through MatchForGray, whose outcome is logged but not propagated.
 *
 * @param basic_info selection context; node, is_input, index, op_kernel_info_ptr,
 *                   tensor_kernel_info_ptr and matched_index_vec are used.
 * @return the AllowFp32ToFp16 result when there is no father; FAILED when the
 *         supported dtypes cannot be queried; SUCCESS otherwise. Null graph
 *         objects along the way are rejected by FE_CHECK_NOTNULL.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::RunForOpInGrayList(FormatDtypeSelectionBasicInfo& basic_info) {
  FE_CHECK_NOTNULL(basic_info.node);
  auto cur_op_desc_ptr = basic_info.node->GetOpDesc();
  FE_CHECK_NOTNULL(cur_op_desc_ptr);
  std::string cur_op_desc_name = cur_op_desc_ptr->GetName();
  std::string cur_op_desc_type = cur_op_desc_ptr->GetType();
  ge::InDataAnchorPtr in_data_anchor;
  bool has_no_father = false;

  CheckHasNoFather(basic_info.is_input, static_cast<int32_t>(basic_info.index), basic_info.node,
                   in_data_anchor, has_no_father);

  /* 1. If the node is Gray list does not have predecessors, we just match the
   * dtype with its original dtype. */
  if (has_no_father) {
    FE_LOGD("Op[name=%s,type=%s]: the op does not have a father node on input [%u]. Match with its original dtype.",
            cur_op_desc_name.c_str(), cur_op_desc_type.c_str(), basic_info.index);
    AllowFp32ToFp16Selector allow_fp32_to_fp16_selector(
        new OpDtypeSelectionStrategyAllowFp32ToFp16(format_dtype_querier_ptr_,
                                                    op_dtype_rise_matcher_ptr_,
                                                    op_dtype_reduce_matcher_ptr_));
    return allow_fp32_to_fp16_selector->Run(basic_info);
  }

  /* Validate every link of the chain up to the father op desc before use,
   * mirroring the checks already done on the Cast-skipping path below. */
  FE_CHECK_NOTNULL(in_data_anchor);
  auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  FE_CHECK_NOTNULL(peer_out_anchor);
  ge::NodePtr father_node = peer_out_anchor->GetOwnerNode();
  FE_CHECK_NOTNULL(father_node);
  ge::OpDescPtr father_op_desc = father_node->GetOpDesc();
  FE_CHECK_NOTNULL(father_op_desc);
  uint32_t father_out_anchor_index = static_cast<uint32_t>(peer_out_anchor->GetIdx());

  FE_LOGD("Op[name=%s,type=%s]:match format and dtype for the input %u between father Op[name=%s,type=%s] and this op.",
          cur_op_desc_name.c_str(), cur_op_desc_type.c_str(), basic_info.index, father_op_desc->GetName().c_str(),
          father_op_desc->GetType().c_str());

  /* 1.1 If the father node is Cast, and the data type of output of Cast is fp32,
   * we will try to skip this cast and match the dtype in front of cast */
  if (IsOpFp16ToFp32Cast(father_op_desc, father_out_anchor_index)) {
    in_data_anchor = father_node->GetInDataAnchor(0);
    FE_CHECK_NOTNULL(in_data_anchor);
    auto father_out_anchor = in_data_anchor->GetPeerOutAnchor();
    FE_CHECK_NOTNULL(father_out_anchor);
    father_node = father_out_anchor->GetOwnerNode();
    FE_CHECK_NOTNULL(father_node);
    father_op_desc = father_node->GetOpDesc();
    FE_CHECK_NOTNULL(father_op_desc);
    father_out_anchor_index = static_cast<uint32_t>(father_out_anchor->GetIdx());
  }

  /* 2. Match all supported data type with father's output data type. */
  auto father_output_desc = father_op_desc->GetOutputDescPtr(father_out_anchor_index);
  /* The output desc is dereferenced on the next line, so reject a null one. */
  FE_CHECK_NOTNULL(father_output_desc);
  ge::DataType father_output_dtype = father_output_desc->GetDataType();

  FE_CHECK_NOTNULL(format_dtype_querier_ptr_);
  std::vector<ge::DataType> op_kernel_dtype_vec;
  if (format_dtype_querier_ptr_->GetSupportDataTypes(basic_info.op_kernel_info_ptr, basic_info.tensor_kernel_info_ptr,
                                                     *(cur_op_desc_ptr.get()), op_kernel_dtype_vec) != SUCCESS) {
    REPORT_FE_ERROR("[GraphOpt][DtypeJdg][RunGrayListOp] Fail to get the support data_types, return FAILED.");
    return FAILED;
  }
  MatchForGray(cur_op_desc_type, cur_op_desc_name, op_kernel_dtype_vec, father_output_dtype, basic_info);
  return SUCCESS;
}

/**
 * Entry point of the mixed-precision dtype selection for one tensor.
 *
 * Reads the op's precision policy from basic_info.op_kernel_info_ptr and
 * dispatches on it: BLACK goes to RunForOpInBlackList, WHITE goes to
 * RunForOpInWhiteList, and every other value goes to RunForOpInGrayList.
 *
 * @param basic_info selection context for the tensor being processed.
 * @return FAILED when the precision policy cannot be obtained, otherwise the
 *         status of the list-specific handler. A null node or op desc is
 *         rejected by FE_CHECK_NOTNULL.
 */
Status OpDtypeSelectionStrategyAllowMixPrecision::Run(FormatDtypeSelectionBasicInfo& basic_info) {
  FE_CHECK_NOTNULL(basic_info.node);
  auto op_desc = basic_info.node->GetOpDesc();
  FE_CHECK_NOTNULL(op_desc);
  std::string op_name = op_desc->GetName();
  std::string op_type = op_desc->GetType();

  PrecisionPolicy precision_policy;
  const Status policy_status = GetPrecisionPolicy(basic_info.op_kernel_info_ptr, precision_policy);
  if (policy_status != SUCCESS) {
    FE_LOGD("Op[name=%s,type=%s]: Failed to get precision policy.", op_name.c_str(), op_type.c_str());
    return FAILED;
  }
  FE_LOGD("Op[name=%s,type=%s]: precision policy is %u.", op_name.c_str(), op_type.c_str(),
          precision_policy);

  /* If the ops is in black list, we must use its original data type */
  if (precision_policy == BLACK) {
    return RunForOpInBlackList(basic_info);
  }
  if (precision_policy == WHITE) {
    return RunForOpInWhiteList(basic_info);
  }
  /* Anything that is neither BLACK nor WHITE is handled as gray list. */
  return RunForOpInGrayList(basic_info);
}
}
